import argparse

"""
Parameters for training: the shared command-line options and the
per-algorithm hyperparameter presets.

"""


def _str2bool(value):
    """Convert a command-line token into a bool.

    ``type=bool`` cannot be used with argparse: ``bool('False')`` is ``True``,
    so every non-empty value would enable the option. The empty string is
    accepted as false because it was the only value that produced ``False``
    under ``type=bool``.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


def get_common_args():
    """Parse the command-line options shared by all algorithms.

    Reads ``sys.argv`` through ``argparse``. Boolean options take an explicit
    value (e.g. ``--cuda False``).

    Returns:
        argparse.Namespace: the parsed options.

    Raises:
        SystemExit: raised by argparse when the command line is invalid.
    """
    parser = argparse.ArgumentParser()

    # the environment setting
    parser.add_argument('--difficulty', type=str, default='7', help='the difficulty of the game')
    parser.add_argument('--game_version', type=str, default='latest', help='the version of the game')
    parser.add_argument('--map', type=str, default='3_vs_2', help='the map of the game')
    parser.add_argument('--step_mul', type=int, default=8, help='how many steps to make an action')
    parser.add_argument('--replay_dir', type=str, default='', help='absolute path to save the replay')

    # The alternative algorithms are vdn, coma, central_v, qmix, qtran_base,
    # qtran_alt, reinforce, coma+commnet, central_v+commnet, reinforce+commnet,
    # coma+g2anet, central_v+g2anet, reinforce+g2anet, maven
    parser.add_argument('--alg', type=str, default='qmix', help='the algorithm to train the agent')
    parser.add_argument('--n_steps', type=int, default=2000000, help='total time steps')
    parser.add_argument('--n_episodes', type=int, default=1, help='the number of episodes before once training')
    parser.add_argument('--last_action', type=_str2bool, default=True,
                        help='whether to use the last action to choose action')
    parser.add_argument('--reuse_network', type=_str2bool, default=True,
                        help='whether to use one network for all agents')
    parser.add_argument('--gamma', type=float, default=0.99, help='discount factor')
    parser.add_argument('--optimizer', type=str, default="RMS", help='optimizer')
    parser.add_argument('--evaluate_cycle', type=int, default=5000, help='how often to evaluate the model')
    parser.add_argument('--evaluate_epoch', type=int, default=20, help='number of the epoch to evaluate the agent')
    parser.add_argument('--model_dir', type=str, default='./model', help='model directory of the policy')
    parser.add_argument('--result_dir', type=str, default='./result', help='result directory of the policy')
    parser.add_argument('--load_model', type=_str2bool, default=False,
                        help='whether to load the pretrained model')
    parser.add_argument('--evaluate', type=_str2bool, default=False, help='whether to evaluate the model')
    parser.add_argument('--cuda', type=_str2bool, default=True, help='whether to use the GPU')
    parser.add_argument('--burn_in_period', type=int, default=10)
    parser.add_argument('--render', type=_str2bool, default=False,
                        help='whether to render the environment')
    # Takes a name. ``nargs='?'`` keeps the bare ``--project_name`` flag
    # accepted (it used to be a store_true switch) and maps it to the default.
    parser.add_argument('--project_name', type=str, nargs='?', const='RLFuture', default='RLFuture',
                        help='name of the project')

    # Other --env defaults that were used previously: MMM2, MMM3, MMM4, 3m,
    # 3s5z_vs_3s6z, 3s5z_vs_3s7z, 3s_vs_8z, 3s_vs_5z, 6h_vs_8z, 10m_vs_11m,
    # 12m_vs_14m, pacmen, 2_vs_2, 3_vs_2, 3_vs_3, 3_vs_4_BackBall,
    # counterattack, simple_spread_n3.py
    parser.add_argument('--env', type=str, default='corridor')
    # Disabled options kept for reference: --m (int, default 1, 'GreenBean'),
    # --c (float, default 0, 'GreenReward'), --norm (bool, default False).

    parser.add_argument('--uRNN', type=_str2bool, default=True, help='whether to use utility network')
    parser.add_argument('--wandb', type=_str2bool, default=True, help='whether to use the wandb')
    parser.add_argument('--batch_size', type=int, default=32, help='batch_size')
    parser.add_argument('--seed', type=int, default=1252, help='random seed')
    parser.add_argument('--label', type=str, default='11.01')

    parser.add_argument('--beta1', type=float, default=1)  # .5
    parser.add_argument('--beta2', type=float, default=1)  # 2. q
    parser.add_argument('--beta', type=float, default=.01)
    parser.add_argument('--td_lambda', type=float, default=.6)

    parser.add_argument('--beta3', type=float, default=0)
    parser.add_argument('--beta4', type=float, default=0.05)
    parser.add_argument('--beta5', type=float, default=1)

    # argparse does not run ``type`` on non-string defaults, so the default
    # must already be an int to match values given on the command line.
    parser.add_argument('--start_anneal_time', type=int, default=int(6.5e7))
    parser.add_argument('--anneal_rate', type=float, default=.5)
    parser.add_argument('--anneal_type', type=str, default='linear')

    return parser.parse_args()


# arguments of coma
def get_coma_args(args):
    """Attach the fixed COMA hyperparameters to ``args`` and return it.

    ``args`` is modified in place; any attribute of the same name that is
    already present (including ``td_lambda``) is overwritten.
    """
    coma_preset = {
        # network
        'rnn_hidden_dim': 64,
        'critic_dim': 128,
        'lr_actor': 1e-4,
        'lr_critic': 1e-3,
        # epsilon-greedy
        'epsilon': 0.5,
        'anneal_epsilon': 0.00064,
        'min_epsilon': 0.02,
        'epsilon_anneal_scale': 'episode',
        # lambda of td-lambda return
        'td_lambda': 0.8,
        # how often to save the model
        'save_cycle': 5000,
        # how often to update the target_net
        'target_update_cycle': 200,
        # prevent gradient explosion
        'grad_norm_clip': 10,
    }
    for field, value in coma_preset.items():
        setattr(args, field, value)
    return args


# arguments of vdn, qmix, qtran
def get_mixer_args(args):
    """Populate ``args`` with the mixer-family hyperparameters and return it.

    Reads ``args.env`` to choose ``two_hyper_layers``; every other value is a
    fixed preset. ``args`` is modified in place. ``batch_size`` is not set
    here.
    """
    # network
    for field, width in (('rnn_hidden_dim', 64), ('qmix_hidden_dim', 32),
                         ('hyper_hidden_dim', 64), ('qtran_hidden_dim', 64)):
        setattr(args, field, width)
    args.lr = 5e-4
    # 'corridor' and '3s_vs_8z' get False, every other env gets True
    args.two_hyper_layers = args.env not in ('corridor', '3s_vs_8z')

    # epsilon greedy: anneal from eps_start down to eps_floor over decay_steps
    eps_start, eps_floor, decay_steps = 1, 0.05, 50000
    args.epsilon = eps_start
    args.min_epsilon = eps_floor
    args.anneal_epsilon = (eps_start - eps_floor) / decay_steps
    args.epsilon_anneal_scale = 'step'

    # the number of the train steps in one epoch
    args.train_steps = 1

    # experience replay
    args.buffer_size = 5000

    # how often to save the model / update the target_net
    args.save_cycle = 5000
    args.target_update_cycle = 200

    # QTRAN lambda
    args.lambda_opt = args.lambda_nopt = 1

    # prevent gradient explosion
    args.grad_norm_clip = 10

    # MAVEN
    maven_preset = (('noise_dim', 16), ('lambda_mi', 0.001),
                    ('lambda_ql', 1), ('entropy_coefficient', 0.001))
    for field, value in maven_preset:
        setattr(args, field, value)
    return args


# arguments of central_v
def get_centralv_args(args):
    """Write the central_v hyperparameter preset onto ``args`` and return it.

    ``args`` is modified in place; attributes of the same name that already
    exist (including ``td_lambda``) are replaced.
    """
    preset = (
        # network
        ('rnn_hidden_dim', 64),
        ('critic_dim', 128),
        ('lr_actor', 1e-4),
        ('lr_critic', 1e-3),
        # epsilon-greedy
        ('epsilon', 0.5),
        ('anneal_epsilon', 0.00064),
        ('min_epsilon', 0.02),
        ('epsilon_anneal_scale', 'episode'),
        # lambda of td-lambda return
        ('td_lambda', 0.8),
        # how often to save the model
        ('save_cycle', 5000),
        # how often to update the target_net
        ('target_update_cycle', 200),
        # prevent gradient explosion
        ('grad_norm_clip', 10),
    )
    for attr_name, attr_value in preset:
        setattr(args, attr_name, attr_value)
    return args


# arguments of reinforce
def get_reinforce_args(args):
    """Write the REINFORCE hyperparameter preset onto ``args`` and return it.

    ``args`` is modified in place. This preset sets neither ``td_lambda`` nor
    ``target_update_cycle``.
    """
    # network
    args.rnn_hidden_dim, args.critic_dim = 64, 128
    args.lr_actor, args.lr_critic = 1e-4, 1e-3

    # epsilon-greedy
    exploration = {
        'epsilon': 0.5,
        'anneal_epsilon': 0.00064,
        'min_epsilon': 0.02,
        'epsilon_anneal_scale': 'episode',
    }
    for attr_name, attr_value in exploration.items():
        setattr(args, attr_name, attr_value)

    # how often to save the model
    args.save_cycle = 5000
    # prevent gradient explosion
    args.grad_norm_clip = 10
    return args


# arguments of coma+commnet
def get_commnet_args(args):
    """Set ``args.k`` from ``args.map`` and return ``args``.

    ``k`` is 2 for the '3m' map and 3 for every other map.
    """
    args.k = 2 if args.map == '3m' else 3
    return args


def get_g2anet_args(args):
    """Set ``attention_dim`` to 32 and ``hard`` to True on ``args``; return it."""
    for attr_name, attr_value in (('attention_dim', 32), ('hard', True)):
        setattr(args, attr_name, attr_value)
    return args

def get_qplex_args(args):
    """Populate ``args`` with the QPLEX hyperparameters and return it.

    Reads ``args.QPLEX_mixer``, ``args.env`` and ``args.n_head`` (and
    ``args.alg`` when the env is '4_vs_3'); these must already be present on
    ``args``. ``args`` is modified in place, and ``n_episodes`` and
    ``evaluate_cycle`` are overwritten.
    """
    # network and exploration bounds
    fixed_head = (
        ('rnn_hidden_dim', 64),
        ('qmix_hidden_dim', 64),
        ('two_hyper_layers', False),
        ('hyper_hidden_dim', 64),
        ('qtran_hidden_dim', 64),
        ('lr', 5e-4),
        ('epsilon_start', 1),
        ('epsilon_finish', 0.05),
    )
    for attr_name, attr_value in fixed_head:
        setattr(args, attr_name, attr_value)

    # NOTE(review): this check runs before the pacmen branch below forces
    # QPLEX_mixer to 'dmaq', so weighted_head is not set by that branch --
    # confirm whether that ordering is intended.
    if args.QPLEX_mixer == "dmaq":
        args.weighted_head = True

    env_name = args.env

    # env-specific alpha / epsilon annealing
    if env_name == '3_vs_2':
        args.alpha = 0.1
        args.epsilon_anneal_time = 50000
    elif env_name == '4_vs_3' and args.alg == 'CDS':
        args.alpha = 0.8

    # env-specific number of epochs
    if 'pacmen' in env_name:
        args.QPLEX_mixer = 'dmaq'
        args.n_epoch = 70000
    elif env_name in ('3_vs_2', '2_vs_3'):
        args.n_epoch = 130000
    elif env_name in ('3_vs_3_full', '3_vs_5_full', 'counterattack'):
        args.n_epoch = 600000
    else:
        args.n_epoch = 200000

    fixed_tail = (
        ('n_episodes', 1),
        ('evaluate_cycle', 200),
        ('buffer_size', 5000),
        ('save_cycle', 1000),
        ('optim_alpha', 0.99),
        ('optim_eps', 0.00001),
        ('target_update_cycle', 200),
        ('grad_norm_clip', 10),
    )
    for attr_name, attr_value in fixed_tail:
        setattr(args, attr_name, attr_value)
    args.num_kernel = args.n_head
    return args
